import torch
import matplotlib as plt
from torchvision import datasets,transforms

# `import matplotlib as plt` at the top of the file binds the matplotlib
# package itself, which has no `imshow`/`show`; rebind `plt` to pyplot here.
import matplotlib.pyplot as plt

# Convert each PIL image from the dataset into a torch tensor.
transform = transforms.Compose([transforms.ToTensor()])

# Download the MNIST training split into ./.pytorch/MNIST_data if it is not
# already present, and serve it in shuffled batches of 64.
trainset = datasets.MNIST('./.pytorch/MNIST_data', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

# Pull a single batch. Use the builtin next(): DataLoader iterators no longer
# provide a `.next()` method in recent PyTorch releases.
dataiter = iter(trainloader)
images, labels = next(dataiter)
# Batch dimension is 64 (batch_size above); for MNIST the images tensor is
# expected to be (64, 1, 28, 28) and labels (64,) — confirm against output.
print(images.shape)
print(labels.shape)

# Preview the second image of the batch. squeeze() drops size-1 dimensions
# (e.g. a single channel) so imshow receives a 2-D array.
plt.imshow(images[1].numpy().squeeze(), cmap='Greys_r')
# Without show() a non-interactive script exits without displaying the figure.
plt.show()